package com.jstarcraft.ai.jsat.io;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.io.Reader;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import com.jstarcraft.ai.jsat.DataSet;
import com.jstarcraft.ai.jsat.DataStore;
import com.jstarcraft.ai.jsat.RowMajorStore;
import com.jstarcraft.ai.jsat.classifiers.CategoricalData;
import com.jstarcraft.ai.jsat.classifiers.ClassificationDataSet;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.datatransform.DenseSparceTransform;
import com.jstarcraft.ai.jsat.linear.IndexValue;
import com.jstarcraft.ai.jsat.linear.SparseVector;
import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.regression.RegressionDataSet;
import com.jstarcraft.ai.jsat.utils.StringUtils;

import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.ints.IntArrayList;

/**
 * Loads a LIBSVM data file into a {@link DataSet}. LIVSM files do not indicate
 * whether or not the target variable is supposed to be numerical or
 * categorical, so two different loading methods are provided. For a LIBSVM file
 * to be loaded correctly, it must match the LIBSVM spec without extensions.
 * <br>
 * <br>
 * Each line should begin with a numeric value. This is either a regression
 * target or a class label. <br>
 * Then, for each non zero value in the data set, a space should precede an
 * integer value index starting from 1 followed by a colon ":" followed by a
 * numeric feature value. <br>
 * The single space at the beginning should be the only space. There should be
 * no double spaces in the file. <br>
 * <br>
 * LIBSVM files do not explicitly specify the length of data vectors. This can
 * be problematic if loading a testing and training data set, if the data sets
 * do not include the same highest index as a non-zero value, the data sets will
 * have incompatible vector lengths. To resolve this issue, use the loading
 * methods that include the optional {@code vectorLength} parameter to specify
 * the length before hand.
 * 
 * @author Edward Raff
 */
public class LIBSVMLoader {
    private static boolean fastLoad = true;

    private LIBSVMLoader() {
    }

    /*
     * LIBSVM format is sparse <VAL> <1 based Index>:<Value>
     * 
     */

    /**
     * Loads a new regression data set from a LIBSVM file, assuming the label is a
     * numeric target value to predict
     * 
     * @param file the file to load
     * @return a regression data set
     * @throws FileNotFoundException if the file was not found
     * @throws IOException           if an error occurred reading the input stream
     */
    public static RegressionDataSet loadR(File file) throws FileNotFoundException, IOException {
        return loadR(file, 0.5);
    }

    /**
     * Loads a new regression data set from a LIBSVM file, assuming the label is a
     * numeric target value to predict
     * 
     * @param file        the file to load
     * @param sparseRatio the fraction of non zero values to qualify a data point as
     *                    sparse
     * @return a regression data set
     * @throws FileNotFoundException if the file was not found
     * @throws IOException           if an error occurred reading the input stream
     */
    public static RegressionDataSet loadR(File file, double sparseRatio) throws FileNotFoundException, IOException {
        return loadR(file, sparseRatio, -1);
    }

    /**
     * Loads a new regression data set from a LIBSVM file, assuming the label is a
     * numeric target value to predict
     * 
     * @param file         the file to load
     * @param sparseRatio  the fraction of non zero values to qualify a data point
     *                     as sparse
     * @param vectorLength the pre-determined length of each vector. If given a
     *                     negative value, the largest non-zero index observed in
     *                     the data will be used as the length.
     * @return a regression data set
     * @throws FileNotFoundException if the file was not found
     * @throws IOException           if an error occurred reading the input stream
     */
    public static RegressionDataSet loadR(File file, double sparseRatio, int vectorLength) throws FileNotFoundException, IOException {
        return loadR(new FileReader(file), sparseRatio, vectorLength);
    }

    /**
     * Loads a new regression data set from a LIBSVM file, assuming the label is a
     * numeric target value to predict
     * 
     * @param isr         the input stream for the file to load
     * @param sparseRatio the fraction of non zero values to qualify a data point as
     *                    sparse
     * @return a regression data set
     * @throws IOException if an error occurred reading the input stream
     */
    public static RegressionDataSet loadR(InputStreamReader isr, double sparseRatio) throws IOException {
        return loadR(isr, sparseRatio, -1);
    }

    /**
     * Loads a new regression data set from a LIBSVM file, assuming the label is a
     * numeric target value to predict.
     * 
     * @param reader       the reader for the file to load
     * @param sparseRatio  the fraction of non zero values to qualify a data point
     *                     as sparse
     * @param vectorLength the pre-determined length of each vector. If given a
     *                     negative value, the largest non-zero index observed in
     *                     the data will be used as the length.
     * @return a regression data set
     * @throws IOException
     */
    public static RegressionDataSet loadR(Reader reader, double sparseRatio, int vectorLength) throws IOException {
        return loadR(reader, sparseRatio, vectorLength, DataStore.DEFAULT_STORE);
    }

    /**
     * Loads a new regression data set from a LIBSVM file, assuming the label is a
     * numeric target value to predict.
     * 
     * @param reader       the reader for the file to load
     * @param sparseRatio  the fraction of non zero values to qualify a data point
     *                     as sparse
     * @param vectorLength the pre-determined length of each vector. If given a
     *                     negative value, the largest non-zero index observed in
     *                     the data will be used as the length.
     * @param store        the type of store to use for data
     * @return a regression data set
     * @throws IOException
     */
    public static RegressionDataSet loadR(Reader reader, double sparseRatio, int vectorLength, DataStore store) throws IOException {
        return (RegressionDataSet) loadG(reader, sparseRatio, vectorLength, false, store);
    }

    /**
     * Loads a new classification data set from a LIBSVM file, assuming the label is
     * a nominal target value
     * 
     * @param file the file to load
     * @return a classification data set
     * @throws FileNotFoundException if the file was not found
     * @throws IOException           if an error occurred reading the input stream
     */
    public static ClassificationDataSet loadC(File file) throws FileNotFoundException, IOException {
        return loadC(new FileReader(file), 0.5);
    }

    /**
     * Loads a new classification data set from a LIBSVM file, assuming the label is
     * a nominal target value
     * 
     * @param file        the file to load
     * @param sparseRatio the fraction of non zero values to qualify a data point as
     *                    sparse
     * @return a classification data set
     * @throws FileNotFoundException if the file was not found
     * @throws IOException           if an error occurred reading the input stream
     */
    public static ClassificationDataSet loadC(File file, double sparseRatio) throws FileNotFoundException, IOException {
        return loadC(file, sparseRatio, -1);
    }

    /**
     * Loads a new classification data set from a LIBSVM file, assuming the label is
     * a nominal target value
     * 
     * @param file         the file to load
     * @param sparseRatio  the fraction of non zero values to qualify a data point
     *                     as sparse
     * @param vectorLength the pre-determined length of each vector. If given a
     *                     negative value, the largest non-zero index observed in
     *                     the data will be used as the length.
     * @return a classification data set
     * @throws FileNotFoundException if the file was not found
     * @throws IOException           if an error occurred reading the input stream
     */
    public static ClassificationDataSet loadC(File file, double sparseRatio, int vectorLength) throws FileNotFoundException, IOException {
        return loadC(new FileReader(file), sparseRatio, vectorLength);
    }

    /**
     * Loads a new classification data set from a LIBSVM file, assuming the label is
     * a nominal target value
     * 
     * @param isr         the input stream for the file to load
     * @param sparseRatio the fraction of non zero values to qualify a data point as
     *                    sparse
     * @return a classification data set
     * @throws IOException if an error occurred reading the input stream
     */
    public static ClassificationDataSet loadC(InputStreamReader isr, double sparseRatio) throws IOException {
        return loadC(isr, sparseRatio, -1);
    }

    /**
     * Loads a new classification data set from a LIBSVM file, assuming the label is
     * a nominal target value
     * 
     * @param reader       the input stream for the file to load
     * @param sparseRatio  the fraction of non zero values to qualify a data point
     *                     as sparse
     * @param vectorLength the pre-determined length of each vector. If given a
     *                     negative value, the largest non-zero index observed in
     *                     the data will be used as the length.
     * @param store        the type of store to use for the data
     * @return a classification data set
     * @throws IOException if an error occurred reading the input stream
     */
    public static ClassificationDataSet loadC(Reader reader, double sparseRatio, int vectorLength) throws IOException {
        return loadC(reader, sparseRatio, vectorLength, DataStore.DEFAULT_STORE);
    }

    /**
     * Loads a new classification data set from a LIBSVM file, assuming the label is
     * a nominal target value
     * 
     * @param reader       the input stream for the file to load
     * @param sparseRatio  the fraction of non zero values to qualify a data point
     *                     as sparse
     * @param vectorLength the pre-determined length of each vector. If given a
     *                     negative value, the largest non-zero index observed in
     *                     the data will be used as the length.
     * @param store        the type of store to use for the data
     * @return a classification data set
     * @throws IOException if an error occurred reading the input stream
     */
    public static ClassificationDataSet loadC(Reader reader, double sparseRatio, int vectorLength, DataStore store) throws IOException {
        return (ClassificationDataSet) loadG(reader, sparseRatio, vectorLength, true, store);
    }

    /**
     * Generic loader for both Classification and Regression interpretations.
     * 
     * @param reader
     * @param sparseRatio
     * @param vectorLength
     * @param classification {@code true} to treat as classification, {@code false}
     *                       to treat as regression
     * @return
     * @throws IOException
     */
    private static DataSet loadG(Reader reader, double sparseRatio, int vectorLength, boolean classification, DataStore store) throws IOException {
        StringBuilder processBuffer = new StringBuilder(20);
        StringBuilder charBuffer = new StringBuilder(1024);
        char[] buffer = new char[1024];
        DataStore sparceVecs = store.emptyClone();
        sparceVecs.setCategoricalDataInfo(new CategoricalData[0]);
        /**
         * The category "label" for each value loaded in
         */
        DoubleArrayList labelVals = new DoubleArrayList();
        Map<Double, Integer> possibleCats = new HashMap<>();
        int maxLen = 1;

        STATE state = STATE.INITIAL;
        int position = 0;
        SparseVector tempVec = new SparseVector(1, 1);
        /**
         * The index that we have parse out of a non zero pair
         */
        int indexProcessing = -1;
        while (true) {

            while (charBuffer.length() - position <= 1)// make sure we have chars to handle
            {
                // move everything to the front
                charBuffer.delete(0, position);
                position = 0;

                int read = reader.read(buffer);
                if (read < 0)
                    break;
                charBuffer.append(buffer, 0, read);
            }

            if (charBuffer.length() - position == 0)// EOF, no more chars
            {
                if (state == STATE.LABEL)// last line was empty
                {
                    double label = Double.parseDouble(processBuffer.toString());

                    if (!possibleCats.containsKey(label) && classification)
                        possibleCats.put(label, possibleCats.size());
                    labelVals.add(label);

                    sparceVecs.addDataPoint(new DataPoint(new SparseVector(maxLen, 0)));
                } else if (state == STATE.WHITESPACE_AFTER_LABEL)// last line was empty, but we have already eaten the label
                {
                    sparceVecs.addDataPoint(new DataPoint(new SparseVector(maxLen, 0)));
                } else if (state == STATE.FEATURE_VALUE || state == STATE.WHITESPACE_AFTER_FEATURE)// line ended after a value pair
                {
                    // process the last value pair & insert into vec
                    double value = StringUtils.parseDouble(processBuffer, 0, processBuffer.length());
                    processBuffer.delete(0, processBuffer.length());

                    maxLen = Math.max(maxLen, indexProcessing + 1);
                    tempVec.setLength(maxLen);
                    if (value != 0)
                        tempVec.set(indexProcessing, value);
                    sparceVecs.addDataPoint(new DataPoint(tempVec.clone()));
                } else if (state == STATE.NEWLINE) {
                    // nothing to do and everything already processed, just return
                    break;
                } else
                    throw new RuntimeException();
                // we may have ended on a line, and have a sparse vec to add before returning
                break;
            }

            char ch = charBuffer.charAt(position);
            switch (state) {
            case INITIAL:
                state = STATE.LABEL;
                break;
            case LABEL:
                if (Character.isDigit(ch) || ch == '.' || ch == 'E' || ch == 'e' || ch == '-' || ch == '+') {
                    processBuffer.append(ch);
                    position++;
                } else if (Character.isWhitespace(ch))// this gets spaces and new lines
                {
                    double label = Double.parseDouble(processBuffer.toString());

                    if (!possibleCats.containsKey(label) && classification)
                        possibleCats.put(label, possibleCats.size());
                    labelVals.add(label);

                    // clean up and move to new state
                    processBuffer.delete(0, processBuffer.length());

                    if (ch == '\n' || ch == '\r')// empty line, so add a zero vector
                    {
                        tempVec.zeroOut();
                        sparceVecs.addDataPoint(new DataPoint(new SparseVector(maxLen, 0)));
                        state = STATE.NEWLINE;
                    } else// just white space
                    {
                        tempVec.zeroOut();
                        state = STATE.WHITESPACE_AFTER_LABEL;
                    }
                } else
                    throw new RuntimeException("Invalid LIBSVM file");
                break;
            case WHITESPACE_AFTER_LABEL:
                if (Character.isDigit(ch))// move to next state
                {
                    state = STATE.FEATURE_INDEX;
                } else if (Character.isWhitespace(ch)) {
                    if (ch == '\n' || ch == '\r') {
                        tempVec.zeroOut();
                        sparceVecs.addDataPoint(new DataPoint(new SparseVector(maxLen, 0)));/// no features again, add zero vec
                        state = STATE.NEWLINE;
                    } else// normal whie space
                        position++;
                } else
                    throw new RuntimeException();
                break;
            case FEATURE_INDEX:
                if (Character.isDigit(ch)) {
                    processBuffer.append(ch);
                    position++;
                } else if (ch == ':') {
                    indexProcessing = StringUtils.parseInt(processBuffer, 0, processBuffer.length()) - 1;
                    processBuffer.delete(0, processBuffer.length());

                    state = STATE.FEATURE_VALUE;
                    position++;
                } else
                    throw new RuntimeException();
                break;
            case FEATURE_VALUE:
                // we need to accept all the values that may be part of a float value
                if (Character.isDigit(ch) || ch == '.' || ch == 'E' || ch == 'e' || ch == '-' || ch == '+') {
                    processBuffer.append(ch);
                    position++;
                } else {
                    double value = StringUtils.parseDouble(processBuffer, 0, processBuffer.length());
                    processBuffer.delete(0, processBuffer.length());

                    maxLen = Math.max(maxLen, indexProcessing + 1);
                    tempVec.setLength(maxLen);
                    if (value != 0)
                        tempVec.set(indexProcessing, value);

                    if (Character.isWhitespace(ch))
                        state = STATE.WHITESPACE_AFTER_FEATURE;
                    else
                        throw new RuntimeException();
                }

                break;
            case WHITESPACE_AFTER_FEATURE:
                if (Character.isDigit(ch))
                    state = STATE.FEATURE_INDEX;
                else if (Character.isWhitespace(ch)) {
                    if (ch == '\n' || ch == '\r') {
                        sparceVecs.addDataPoint(new DataPoint(tempVec.clone()));
                        tempVec.zeroOut();
                        state = STATE.NEWLINE;
                    } else
                        position++;
                }
                break;
            case NEWLINE:
                if (ch == '\n' || ch == '\r')
                    position++;
                else {
                    state = STATE.LABEL;
                }
                break;
            }
        }

        if (vectorLength > 0)
            if (maxLen > vectorLength)
                throw new RuntimeException("Length given was " + vectorLength + ", but observed length was " + maxLen);
            else
                maxLen = vectorLength;

        if (classification) {
            CategoricalData predicting = new CategoricalData(possibleCats.size());

            // Give categories a unique ordering to avoid loading issues based on the order
            // categories are presented
            DoubleArrayList allCatKeys = new DoubleArrayList(possibleCats.keySet());
            Collections.sort(allCatKeys);
            for (int i = 0; i < allCatKeys.size(); i++)
                possibleCats.put(allCatKeys.getDouble(i), i);
            // apply to target values now

            IntArrayList label_targets = IntArrayList.wrap(labelVals.stream().mapToInt(possibleCats::get).toArray());

            sparceVecs.setNumNumeric(maxLen);
            sparceVecs.finishAdding();
            ClassificationDataSet cds = new ClassificationDataSet(sparceVecs, label_targets);

            if (store instanceof RowMajorStore)
                cds.applyTransform(new DenseSparceTransform(sparseRatio));

            return cds;
        } else// regression
        {
            sparceVecs.setNumNumeric(maxLen);
            sparceVecs.finishAdding();
            RegressionDataSet rds = new RegressionDataSet(sparceVecs, labelVals);
            rds.applyTransform(new DenseSparceTransform(sparseRatio));

            return rds;
        }
    }

    /**
     * Writes out the given classification data set as a LIBSVM data file
     * 
     * @param data the data set to write to a file
     * @param os   the output stream to write to. The stream will not be closed or
     *             flushed by this method
     */
    public static void write(ClassificationDataSet data, OutputStream os) {
        PrintWriter writer = new PrintWriter(os);
        for (int i = 0; i < data.size(); i++) {
            int pred = data.getDataPointCategory(i);
            Vec vals = data.getDataPoint(i).getNumericalValues();
            writer.write(pred + " ");
            for (IndexValue iv : vals) {
                double val = iv.getValue();
                if (Math.rint(val) == val)// cast to long before writting to save space
                    writer.write((iv.getIndex() + 1) + ":" + (long) val + " ");// +1 b/c 1 based indexing
                else
                    writer.write((iv.getIndex() + 1) + ":" + val + " ");// +1 b/c 1 based indexing
            }
            writer.write("\n");
        }
        writer.flush();
        writer.close();
    }

    /**
     * Writes out the given regression data set as a LIBSVM data file
     * 
     * @param data the data set to write to a file
     * @param os   the output stream to write to. The stream will not be closed or
     *             flushed by this method
     */
    public static void write(RegressionDataSet data, OutputStream os) {
        PrintWriter writer = new PrintWriter(os);
        for (int i = 0; i < data.size(); i++) {
            double pred = data.getTargetValue(i);
            Vec vals = data.getDataPoint(i).getNumericalValues();
            writer.write(pred + " ");
            for (IndexValue iv : vals) {
                double val = iv.getValue();
                if (Math.rint(val) == val)// cast to long before writting to save space
                    writer.write((iv.getIndex() + 1) + ":" + (long) val + " ");// +1 b/c 1 based indexing
                else
                    writer.write((iv.getIndex() + 1) + ":" + val + " ");// +1 b/c 1 based indexing
            }
            writer.write("\n");
        }
        writer.flush();
        writer.close();
    }

    /**
     * Returns a DataWriter object which can be used to stream a set of arbitrary
     * datapoints into the given output stream. This works in a thread safe
     * manner.<br>
     * Categorical information dose not need to be specified since LIBSVM files
     * can't store categorical features.
     *
     * @param out  the location to store all the data
     * @param dim  information on how many numeric features exist
     * @param type what type of data set (simple, classification, regression) to be
     *             written
     * @return the DataWriter that the actual points can be streamed through
     * @throws IOException
     */
    public static DataWriter getWriter(OutputStream out, int dim, DataWriter.DataSetType type) throws IOException {
        DataWriter dw = new DataWriter(out, new CategoricalData[0], dim, type) {
            @Override
            protected void writeHeader(CategoricalData[] catInfo, int dim, DataWriter.DataSetType type, OutputStream out) {
                // nothing to do, LIBSVM format has no header
            }

            @Override
            protected void pointToBytes(double weight, DataPoint dp, double label, ByteArrayOutputStream byteOut) {
                PrintWriter writer = new PrintWriter(byteOut);

                // write out label
                if (this.type == DataSetType.REGRESSION)
                    writer.write(label + " ");
                else if (this.type == DataSetType.CLASSIFICATION)
                    writer.write((int) label + " ");
                else if (this.type == DataSetType.SIMPLE)
                    writer.write("0 ");

                Vec vals = dp.getNumericalValues();
                for (IndexValue iv : vals) {
                    double val = iv.getValue();
                    if (Math.rint(val) == val)// cast to long before writting to save space
                        writer.write((iv.getIndex() + 1) + ":" + (long) val + " ");// +1 b/c 1 based indexing
                    else
                        writer.write((iv.getIndex() + 1) + ":" + val + " ");// +1 b/c 1 based indexing
                }
                writer.write("\n");
                writer.flush();
            }
        };

        return dw;
    }

    /**
     * Simple state machine used to parse LIBSVM files
     */
    private enum STATE {
        /**
         * Initial state, doesn't actually do anything
         */
        INITIAL, LABEL, WHITESPACE_AFTER_LABEL, FEATURE_INDEX, FEATURE_VALUE, WHITESPACE_AFTER_FEATURE, NEWLINE,
    }

}
